# -*- coding: utf-8 -*-  
'''
Initialize the sohu_thuc_news datasets.

Created on 2021-09-08
@author: luoyi
'''
import utils.conf as conf
from utils.iexicon import LiteWordsWarehouse
from data.sohu_thuc_news.bert_dataset import BertPreTFRecordDataset
from data.sohu_thuc_news.lda_gsdmm_dataset import LdaGSDmmPreDataset
from data.sohu_thuc_news.tbert_dataset import TBertTFRecordDataset
from models.lda.nets_np import LDA
from models.gsdmm.nets_np import GSDMM


# Write the BERT pre-training dataset.
def write_bert_pre_ds(count=100):
    '''Build the BERT pre-training dataset and write it via ``write_tfrecord()``.

    Args:
        count: Forwarded unchanged as ``count`` to ``BertPreTFRecordDataset``.
            Defaults to 100, the value that was previously hard-coded here.
            NOTE(review): presumably limits how many source items are
            processed -- confirm against the dataset class.
    '''
    bert_pre_ds_w = BertPreTFRecordDataset(count=count)
    bert_pre_ds_w.write_tfrecord()


# Write the LDA/GSDMM pre-training dataset.
def write_lda_gsdmm_pre_ds(count=100):
    '''Build the LDA/GSDMM pre-training dataset and write it via ``write_words_dataset()``.

    Args:
        count: Forwarded unchanged as ``count`` to ``LdaGSDmmPreDataset``.
            Defaults to 100, the value that was previously hard-coded here.
            NOTE(review): presumably limits how many source items are
            processed -- confirm against the dataset class.
    '''
    lda_gsdmm_ds = LdaGSDmmPreDataset(count=count)
    lda_gsdmm_ds.write_words_dataset()


# Write the TBERT training dataset.
def write_tbert_ds(count=100, train_count=1000, val_count=200, tfrecord_limit=1024):
    '''Build the TBERT training dataset and write it via ``write_tfrecord()``.

    An ``LDA`` and a ``GSDMM`` model are constructed with K taken from
    ``conf`` and V taken from ``LiteWordsWarehouse``; ``load_weight()`` is
    called on each before they are handed to ``TBertTFRecordDataset``.

    All defaults equal the values that were previously hard-coded here, so
    calling the function with no arguments behaves as before.

    Args:
        count: Forwarded as ``count`` to ``TBertTFRecordDataset``.
        train_count: Forwarded as ``tbert_training_train_count`` to
            ``write_tfrecord``.
        val_count: Forwarded as ``tbert_training_val_count`` to
            ``write_tfrecord``.
        tfrecord_limit: Forwarded as ``tbert_tfrecord_limit`` to
            ``write_tfrecord``.

    NOTE(review): the exact meaning of these four values is defined by
    ``TBertTFRecordDataset`` -- confirm against that class.
    '''
    # Query the vocabulary size once so both models are built with the same V.
    vocab_size = LiteWordsWarehouse.instance().words_count()

    lda = LDA(K=conf.LDA.get_k(), V=vocab_size)
    lda.load_weight()
    dmm = GSDMM(K=conf.GSDMM.get_k(), V=vocab_size)
    dmm.load_weight()

    tbert_ds = TBertTFRecordDataset(count=count, lda=lda, dmm=dmm)
    tbert_ds.write_tfrecord(tbert_training_train_count=train_count,
                            tbert_training_val_count=val_count,
                            tbert_tfrecord_limit=tfrecord_limit)


# Script driver: uncomment the step(s) that should run.
# NOTE(review): the active call below is not wrapped in an
# `if __name__ == "__main__":` guard, so it also executes whenever this
# module is imported -- confirm that is intended.
# write_bert_pre_ds()
# write_lda_gsdmm_pre_ds()
write_tbert_ds()